import torch
from model.zhnnet import ZhnNet

# Export the trained ZhnNet weights stored in 'zhnnet.pth' to an ONNX graph
# written to 'zhnnet.onnx' (ONNX opset 12).
print('Export onnx.')
model = ZhnNet()
# map_location='cpu' lets a checkpoint that was saved from a GPU run be loaded
# on a CPU-only machine; the dummy input below is a CPU tensor, so the model
# must live on the CPU for tracing anyway.
state_dict = torch.load('zhnnet.pth', map_location='cpu')
model.load_state_dict(state_dict)
# Switch to inference mode explicitly so any dropout / batch-norm layers
# (if ZhnNet has them) are exported with their inference behaviour.
model.eval()
# Tracing input. No dynamic_axes are passed, so the exported graph is fixed to
# this shape (2, 3, 480, 640) -- presumably (batch, channels, height, width);
# confirm against ZhnNet's forward().
dummy_input = torch.randn(2, 3, 480, 640)
torch.onnx.export(model, dummy_input, "zhnnet.onnx", opset_version=12)
print('onnx saved.')
